import json
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any
from urllib.parse import urljoin

import requests
from dateutil import parser

from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.logger import setup_logger

# Total number of attempts for each Slab GraphQL request. Deliberately generous:
# some requests fail occasionally for reasons that are not yet understood, even
# with a long (~1 minute) per-request timeout.
SLAB_GRAPHQL_MAX_TRIES = 10
# The single GraphQL endpoint that every query in this module is POSTed to.
SLAB_API_URL = "https://api.slab.com/v1/graphql"
logger = setup_logger()


def run_graphql_request(
    graphql_query: dict, bot_token: str, max_tries: int = SLAB_GRAPHQL_MAX_TRIES
) -> str:
    """POST a GraphQL query to the Slab API and return the raw response body.

    Timeouts and non-200 success responses are retried, for up to ``max_tries``
    attempts in total. Other failures are not retried and propagate
    immediately. These include 4xx/5xx responses, which ``raise_for_status``
    turns into ``requests.exceptions.HTTPError``.

    Args:
        graphql_query: JSON request body, e.g. ``{"query": ..., "variables": {...}}``.
        bot_token: Slab bot token, sent verbatim as the ``Authorization`` header.
        max_tries: total number of attempts; must be at least 1.

    Returns:
        The response body as text; callers parse the JSON themselves.

    Raises:
        ValueError: if ``max_tries`` is less than 1, or if the final attempt
            failed with a non-timeout retryable error.
        TimeoutError: if the final attempt timed out.
    """
    if max_tries < 1:
        raise ValueError(f"max_tries must be at least 1, got {max_tries}")

    headers = {"Authorization": bot_token, "Content-Type": "application/json"}

    for try_count in range(max_tries):
        try:
            response = requests.post(
                SLAB_API_URL, headers=headers, json=graphql_query, timeout=60
            )
            # 4xx/5xx raise HTTPError here, which is NOT caught below (no retry).
            response.raise_for_status()

            # Anything left is a non-error status; only a plain 200 is accepted.
            if response.status_code != 200:
                raise ValueError(f"GraphQL query failed: {graphql_query}")

            return response.text

        # NOTE: some requests errors (e.g. InvalidHeader, InvalidURL) also
        # subclass ValueError, so they are retried too; chaining the cause
        # below keeps the real reason visible once retries are exhausted.
        except (requests.exceptions.Timeout, ValueError) as e:
            if try_count < max_tries - 1:
                logger.warning(
                    f"A Slab GraphQL error occurred "
                    f"(attempt {try_count + 1}/{max_tries}): {e!r}. Retrying..."
                )
                continue

            if isinstance(e, requests.exceptions.Timeout):
                raise TimeoutError(
                    f"Slab API timed out after {max_tries} attempts"
                ) from e
            raise ValueError(
                f"Slab GraphQL query failed after {max_tries} attempts"
            ) from e

    raise RuntimeError(
        "Unexpected execution from Slab Connector. This should not happen."
    )  # for static checker


def get_all_post_ids(bot_token: str) -> list[str]:
    """Return the ``id`` of every post listed under ``organization.posts``.

    Args:
        bot_token: Slab bot token used to authenticate the GraphQL request.
    """
    ids_query = """
        query GetAllPostIds {
            organization {
                posts {
                    id
                }
            }
        }
        """

    response_body = run_graphql_request({"query": ids_query}, bot_token)
    organization = json.loads(response_body)["data"]["organization"]
    return [entry["id"] for entry in organization["posts"]]


def get_post_by_id(post_id: str, bot_token: str) -> dict[str, str]:
    """Fetch one post's ``title``, ``content``, ``linkAccess`` and ``updatedAt``.

    Args:
        post_id: Slab post ID, passed as the ``$postId`` query variable.
        bot_token: Slab bot token used to authenticate the GraphQL request.

    Returns:
        The ``data.post`` object of the GraphQL response.
    """
    post_query = """
        query GetPostById($postId: ID!) {
            post(id: $postId) {
                title
                content
                linkAccess
                updatedAt
            }
        }
        """
    request_body = {"query": post_query, "variables": {"postId": post_id}}
    response_body = run_graphql_request(request_body, bot_token)
    decoded = json.loads(response_body)
    return decoded["data"]["post"]


def iterate_post_batches(
    batch_size: int, bot_token: str
) -> Generator[list[dict[str, str]], None, None]:
    """Page through posts using Slab's ``search`` query and yield one list per page.

    Each yielded item is an edge's ``node`` object. Per the query below, the
    post fields sit under that node's ``post`` key. ``batch_size`` is sent as
    the ``first`` page-size variable. Pages with no results are not yielded.
    Paging continues while ``pageInfo.hasNextPage`` is truthy.

    NOTE: This may not be safe to use; it is unclear whether page edits change
    the order of results between pages.
    """
    search_query = """
        query IteratePostBatches($query: String!, $first: Int, $types: [SearchType], $after: String) {
            search(query: $query, first: $first, types: $types, after: $after) {
                edges {
                    node {
                        ... on PostSearchResult {
                            post {
                                id
                                title
                                content
                                updatedAt
                            }
                        }
                    }
                }
                pageInfo {
                    endCursor
                    hasNextPage
                }
            }
        }
    """
    cursor = None
    while True:
        request_body = {
            "query": search_query,
            "variables": {
                "query": "",
                "first": batch_size,
                "types": ["POST"],
                "after": cursor,
            },
        }
        search_result = json.loads(run_graphql_request(request_body, bot_token))[
            "data"
        ]["search"]
        cursor = search_result["pageInfo"]["endCursor"]

        page = [edge["node"] for edge in search_result["edges"]]
        if page:
            yield page

        # Read only after resuming from the yield, as before.
        if not search_result["pageInfo"]["hasNextPage"]:
            break


def get_slab_url_from_title_id(base_url: str, title: str, page_id: str) -> str:
    """Build a post's browser URL from its title and ID.

    The slug is built from the title in four steps: ``[``, ``]`` and ``:``
    are dropped, spaces become hyphens, and the result is lower-cased. The
    post ID is appended after a hyphen, and the final path segment is joined
    under ``posts/`` relative to ``base_url``.

    This mirrors Slab's URL scheme as currently observed, not a documented
    API, so it may change without notice.
    """
    slug_table = str.maketrans({"[": None, "]": None, ":": None, " ": "-"})
    slug = title.translate(slug_table).lower()
    posts_root = urljoin(base_url, "posts/")
    return urljoin(posts_root, "-".join((slug, page_id)))


class SlabConnector(LoadConnector, PollConnector):
    """Index posts from a Slab workspace through Slab's GraphQL API.

    Each post becomes one ``Document`` keyed by its Slab post ID. The document
    link is reconstructed from ``base_url``, the post title and the post ID.
    """

    def __init__(
        self,
        base_url: str,
        batch_size: int = INDEX_BATCH_SIZE,
        slab_bot_token: str | None = None,
    ) -> None:
        # Root URL of the Slab workspace; used to build each post's link.
        self.base_url = base_url
        # Maximum number of Documents per yielded batch.
        self.batch_size = batch_size
        # May be left None here and supplied later via load_credentials().
        self.slab_bot_token = slab_bot_token

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Store the bot token from ``credentials["slab_bot_token"]``; returns None."""
        self.slab_bot_token = credentials["slab_bot_token"]
        return None

    @staticmethod
    def _content_to_text(raw_content: str) -> str:
        """Concatenate the string ``insert`` values of a post's JSON ``content``.

        ``raw_content`` is decoded as JSON and iterated as a sequence of
        segment dicts. This presumably follows a Quill-delta-like format; TODO
        confirm against the Slab API. Segments whose ``insert`` is missing,
        empty or not a string are skipped.
        """
        text_parts: list[str] = []
        for segment in json.loads(raw_content):
            insert = segment.get("insert")
            # Non-string inserts (presumably embeds such as images) carry no text.
            if insert and isinstance(insert, str):
                text_parts.append(insert)
        # join() instead of repeated += avoids quadratic string building.
        return "".join(text_parts)

    def _iterate_posts(
        self, time_filter: Callable[[datetime], bool] | None = None
    ) -> GenerateDocumentsOutput:
        """Fetch every post and yield Documents in batches of ``self.batch_size``.

        Args:
            time_filter: optional predicate on the post's ``updatedAt`` time
                (always timezone-aware); posts for which it returns False
                are skipped.

        Raises:
            ConnectorMissingCredentialError: if no bot token is set. Because
                this is a generator, it is raised on first iteration.
        """
        if self.slab_bot_token is None:
            raise ConnectorMissingCredentialError("Slab")

        doc_batch: list[Document] = []

        # Posts are fetched one at a time; time filtering can only happen after
        # a post (and its updatedAt) has been retrieved.
        for post_id in get_all_post_ids(self.slab_bot_token):
            post = get_post_by_id(post_id, self.slab_bot_token)
            last_modified = parser.parse(post["updatedAt"])
            if last_modified.tzinfo is None:
                # Comparing a naive datetime with the tz-aware poll window
                # would raise TypeError; assume naive timestamps are UTC.
                last_modified = last_modified.replace(tzinfo=timezone.utc)
            if time_filter is not None and not time_filter(last_modified):
                continue

            page_url = get_slab_url_from_title_id(self.base_url, post["title"], post_id)
            content_text = self._content_to_text(post["content"])

            doc_batch.append(
                Document(
                    id=post_id,  # can't be url as this changes with the post title
                    sections=[Section(link=page_url, text=content_text)],
                    source=DocumentSource.SLAB,
                    semantic_identifier=post["title"],
                    metadata={},
                )
            )

            if len(doc_batch) >= self.batch_size:
                yield doc_batch
                doc_batch = []

        if doc_batch:
            yield doc_batch

    def load_from_state(self) -> GenerateDocumentsOutput:
        """Yield Documents for every post, with no time filtering."""
        yield from self._iterate_posts()

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        """Yield Documents for posts whose ``updatedAt`` lies in ``[start, end]``.

        Both bounds are inclusive. They are Unix timestamps, converted to
        UTC-aware datetimes for the comparison.
        """
        start_time = datetime.fromtimestamp(start, tz=timezone.utc)
        end_time = datetime.fromtimestamp(end, tz=timezone.utc)

        yield from self._iterate_posts(
            time_filter=lambda t: start_time <= t <= end_time
        )
